Add NVFP4 per-token quantization recipe#3045
Conversation
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.
* common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
* common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
to 2d quant of W).
* pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
grouped bulk binding and per-token GEMM entry; thin pybind layer.
* pytorch/custom_recipes/{gemm_nvfp4_per_token,
quantization_nvfp4_per_token_group}.py: Python wrappers.
* tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
cast tests + bf16-close GEMM tests.
* tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
Graphs columns, ratio against per-tensor RHT+SR baseline.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
6f17fe4 to
928ab1c
Compare
for more information, see https://pre-commit.ci
…uped) Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax) and K2 (encode) kernels for both single-tensor and grouped paths. with_rht=False is byte-equal to the pre-RHT code path; when true, applies a 16-pt RHT on the columnwise direction in both K1 and K2 (rowwise stays raw) with outer amax + inner SF self-consistent. Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32 sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into block_amax / block_scale (bit-exact). Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and byte-equality regressions. Benches gain a --rht flag (2-way default, 3-way under --rht). Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K: * single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT) * grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT) Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D). Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).
The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.
with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).
Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.
Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.
Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.
Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.
New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
- nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
(CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
- nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
matching the value cuBLASLt auto-folds via its amax slot.
Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.
Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
- fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
M,N,K = 256..1024.
- fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
(sanity check that the EVT tree and the baked constant are both correct).
Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.
bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
dispatch). Quant + GEMM inside the timing loop, N = K. Function
docstring carries an ASCII kernel-pipeline diagram for both paths
(per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
bwd never re-quantizes). pten side uses RHT-cols + SR grad
quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
carries an ASCII kernel-pipeline diagram (per-step launch budget:
per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
bf16 per-row*per-col post-scale, "Route 1") next to the existing
ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
baseline. Headline ratio lp/cf decides whether to dispatch
per-token through cuBLASLt + post_scale or fused EVT; current
data shows ct_fused wins or ties at every shape we care about.
test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
(catches breakage cheaper than the broader Layer 3 test).
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
for more information, see https://pre-commit.ci
| DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); | ||
| constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad | ||
|
|
||
| dim3 grid(static_cast<unsigned>(K / CHUNK_DIM_X), static_cast<unsigned>(M / CHUNK_DIM_Y), 1); |
There was a problem hiding this comment.
maybe use DIVUP here to handle the remainder case?
There was a problem hiding this comment.
This fast path has a hard precondition that M and K are exact multiples of CHUNK_DIM (128): validate() does NVTE_CHECK(M % CHUNK_DIM_Y == 0) / NVTE_CHECK(K % CHUNK_DIM_X == 0), and is_supported() returns false unless both hold — so any non-multiple shape is rejected / routed to the generic per-token fallback before it ever reaches this launcher.
| // After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot. | ||
| // | ||
| // kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread | ||
| // FHT with random_sign_mask_t). Row direction never sees RHT. |
There was a problem hiding this comment.
typo: Row direction never sees RHT -> Row direction never uses RHT
| } | ||
| } | ||
| #else | ||
| NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell)."); |
There was a problem hiding this comment.
For these quantization kernel, TMA only require SM 9.0+ only. Is there any other constraints that limit to sm 10.0+?
There was a problem hiding this comment.
The CUDA_ARCH >= 1000 guard is intentional but not because of a hardware op in this kernel. Two reasons:
- The shared TE PTX wrappers it calls — cp_async_bulk_tensor_2d_global_to_shared and mbarrier_wait_parity_acquire_cta_shared_cta in util/ptx.cuh — are themselves guarded to >= 1000 and emit NVTE_DEVICE_ERROR below that. They were authored/validated only for the Blackwell path.
- The whole NVFP4 quantize path is host-gated to SM100 anyway (NVTE_ERROR("NVFP4 requires SM100 ...")), since NVFP4 is a Blackwell datatype and the downstream FP4 GEMM that consumes these scales only exists on SM100. So the amax kernel is never launched off <SM100; the per-arch guard just yields a clean error instead of an undefined symbol.
Add NN/NT GEMM layout dispatch so the per-token NVFP4 path covers dgrad and wgrad, and let per-token opt into RHT via NVFP4PerTokenBlockScaling(per_token_rht=...) while SR/2D stay disabled (kernels unimplemented at this commit). Extends the per-token CUTLASS GEMM, the torch NVFP4Quantizer, and the NVFP4Tensor plumbing, plus dgrad/wgrad numerical tests and a fwd+bwd module smoke test. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Thread a Philox rng_state and a kWithSr template flag through the per-token encode kernel (rowwise + colwise) and the nvte_nvfp4_per_token_encode/quantize C-API, mirroring the per-tensor SR path. Drop the SR mutex check in the torch NVFP4Quantizer and build the rng_state when stochastic rounding is requested. Add a per_token_sr recipe flag on NVFP4PerTokenBlockScaling wired through the quantizer factory, plus statistical tests (SR unbiasedness -- lower RMSE than RN when averaged -- and RN-determinism / SR-nondeterminism) folded into test_nvfp4_per_token.py. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Wire with_sr + rng_state through the grouped per-token C-API and cast dispatch, implement the SR FP4 cast in the grouped kernel, and drop the "per-token does not support SR" guard. Also fix two comment typos (sees -> uses) in quantize_nvfp4_per_token.cu per review. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Introduce NVTE_NVFP4_PER_TOKEN_WEIGHT_2D (recipe.per_token_weight_2d), default off so the per-token path stays byte-equal. When enabled, only the forward WEIGHT switches to the per-tensor 2D cast (16x16 inner tile + scalar outer amax) re-dressed in per-token tensor layout: the scalar outer amax is broadcast across the per-row/col alpha vectors and the inner SF is the same 16-row-replicated 2D tile, so the existing per-token CUTLASS GEMM consumes it unchanged with no kernel modification. Activation/gradient casts stay per-token 1D. Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
Document the user-facing surface of the NVFP4 per-token recipe and add a runnable single-GPU example so the recipe can be exercised end to end. - docs/api/common.rst: list NVFP4PerTokenBlockScaling in the API reference. - docs/envvars.rst: document the NVTE_NVFP4_* knobs -- per-token activation (NVTE_NVFP4_PER_TOKEN) plus the RHT/SR/weight-2D opt-ins, and the per-tensor disable flags. - docs/features/.../nvfp4.rst: add a "Per-token NVFP4" section explaining the per-row/per-col outer-amax cast, its differences from the per-tensor default (RHT/SR off by default, forced-off knobs, unfused-norm requirement), and how to launch it with Megatron-Core. - recipe/__init__.py: document the per_token_rht/per_token_sr/per_token_weight_2d constructor kwargs and drop the stale "stochastic rounding unsupported" note. - pytorch/fp8.py: re-export NVFP4PerTokenBlockScaling. - examples/pytorch/nvfp4_per_token_megatron: single-GPU MoE example (run + sbatch + job-chain scripts and README) comparing per-token vs per-tensor vs BF16 with identical model/data/seed. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
Greptile SummaryThis PR adds an NVFP4 per-token quantization recipe to Transformer Engine, replacing the per-tensor outer amax with per-row (length M) and per-col (length K) amax vectors. The implementation spans new CUDA cast kernels (K1 vector-amax + K2 FP4 encode), a fused CUTLASS EVT GEMM that folds the per-row/per-col outer-scale vectors directly into the bf16 epilogue, recipe classes (
Confidence Score: 3/5The PR adds a large new quantization path (13k+ lines) described as an accuracy-evaluation MVP. The new code is well-guarded for the intended end-to-end flow, but the weight-2D path in quantizer.cpp has a null-pointer in the amax restore that, while currently unreachable, lives adjacent to defensively-written null checks and should be fixed before the code grows more callers. The weight-2D amax restore in quantizer.cpp calls out.set_amax(rowwise_amax_ptr, ...) immediately after picking amax_ptr as the non-null fallback. If rowwise_amax_ptr is null, nvte_quantize_v2 would crash reading the amax. The path is shielded today by construction invariants, but the surrounding code already includes defensive nullptr checks that the critical line does not match. Combined with several explicitly documented MVP limitations, the change needs targeted fixes before it can be considered fully production-ready. transformer_engine/pytorch/csrc/quantizer.cpp (per-token weight-2D amax restore), transformer_engine/pytorch/cpp_extensions/gemm.py (alpha forwarding), transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py (0-token split handling) Important Files Changed
Sequence DiagramsequenceDiagram
participant Recipe as NVFP4PerTokenBlockScaling
participant Q as NVFP4Quantizer
participant K1 as nvte_nvfp4_per_token_quantize
participant Tensor as NVFP4Tensor
participant GEMM as _nvfp4_per_token_gemm
participant CGEMM as nvfp4_cutlass_per_token_gemm
Recipe->>Q: "per_token=True rowwise=True columnwise=True"
Q->>Q: amax_rowwise(M,) amax_columnwise(K,)
Q->>K1: quantize input
K1-->>Tensor: data(M,K/2) sf(M,K/16) amax(M,)
K1-->>Tensor: col_data(K,M/2) col_sf(K,M/16) col_amax(K,)
GEMM->>CGEMM: TN/NN/NT layout dispatch
CGEMM-->>GEMM: "D[i,j]=bf16(amax_a[i]*amax_b[j]*(A@B^T)[i,j])"
Reviews (1): Last reviewed commit: "Add docs and Megatron-Core example for t..." | Re-trigger Greptile |
| NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer for per-token weight-2D."); | ||
|
|
||
| // 1. Single scalar tensor amax -> amax[0] (mirror the per-tensor no-RHT path: | ||
| // treat the buffer as length 1 for the reduction, then fan out to both | ||
| // rowwise/columnwise amax[0]). | ||
| out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1}); | ||
| NVTE_SCOPED_GIL_RELEASE( | ||
| { nvte_compute_amax_with_config(input.data(), out.data(), w2d_config, stream); }); | ||
| out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1}); | ||
| if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { | ||
| NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), | ||
| cudaMemcpyDeviceToDevice, stream)); | ||
| } | ||
| if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { | ||
| NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), |
There was a problem hiding this comment.
Latent null-pointer dereference in per-token weight-2D amax restore
After nvte_compute_amax_with_config writes the global amax into amax_ptr[0], the code restores the output tensor's amax field with out.set_amax(rowwise_amax_ptr, ...). If rowwise_amax_ptr is nullptr (i.e., the quantizer was constructed with rowwise=False), this sets the output's amax descriptor to a null pointer. The immediately following nvte_quantize_v2 then tries to read amax[0] to derive S_enc and will crash.
Currently this path is unreachable because per_token_weight_2d is only set for weight quantizers, and all weight quantizers in the recipe are constructed with rowwise=True, columnwise=True. However, the guard in step 3 (if (rowwise_amax_ptr != nullptr && w2d_rows > 1)) shows the author anticipated both pointers could be null, while the critical out.set_amax call on this line does not. Using amax_ptr (the non-null pointer already validated by the NVTE_CHECK above) would be safe in all configurations: out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1}).
| # Per-token NVFP4 dispatches to fused EVT GEMM that consumes per-row | ||
| # (M,) and per-col (N,) outer-amax vectors directly. cuBLASLt cannot, | ||
| # so this MUST short-circuit before the row-scaled-or-generic fork. | ||
| if _is_nvfp4_per_token_tensor(A) or _is_nvfp4_per_token_tensor(B): | ||
| if not (_is_nvfp4_per_token_tensor(A) and _is_nvfp4_per_token_tensor(B)): | ||
| raise NotImplementedError( | ||
| "NVFP4 per-token GEMM requires both A and B to be per-token tensors. " | ||
| "Mixing per-token + prod NVFP4 in one GEMM is not supported." | ||
| ) | ||
| out = _nvfp4_per_token_gemm( | ||
| A, | ||
| B, | ||
| transa=transa, | ||
| transb=transb, | ||
| out=out, | ||
| out_dtype=out_dtype, | ||
| bias=bias, | ||
| grad=grad, | ||
| accumulate=accumulate, | ||
| gelu=gelu, | ||
| quantization_params=quantization_params, | ||
| ub=ub, | ||
| extra_output=extra_output, | ||
| ) |
There was a problem hiding this comment.
alpha scalar silently ignored for per-token GEMM
general_gemm validates and stores alpha in kwargs["alpha"], but the per-token short-circuit path dispatches to _nvfp4_per_token_gemm which has no alpha parameter and never forwards the value. The C++ binding nvfp4_cutlass_per_token_gemm also lacks a global scalar alpha argument — only the per-row/per-col alpha_a/alpha_b vectors are supported. For all current TE module call sites alpha=1.0 is the invariant, so numerical output is unaffected today. If a caller ever passes alpha != 1.0 through general_gemm with per-token tensors, the result will be silently wrong instead of raising an error.
| for i, M_i in enumerate(split_sections): | ||
| if M_i <= 0: | ||
| raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}") | ||
| if M_i % _PER_TOKEN_TILE != 0: | ||
| raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}") |
There was a problem hiding this comment.
Public grouped-quantize API unconditionally rejects 0-token splits
split_sections[i] <= 0 raises ValueError, but in MoE training with dynamic token routing, experts commonly receive zero tokens in a given micro-batch. The general_grouped_gemm per-token loop already handles this by skipping the launch when m_splits[i] == 0, so the GEMM side is fine. If users call this Python wrapper directly (e.g., from bench scripts or custom MoE quantization pipelines), they must pre-filter empty experts. A comment or guard skipping allocation for empty splits would make the API usable in unbalanced-routing scenarios.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Description
This PR adds an NVFP4 per-token quantization recipe for model pre-training. The default NVFP4BlockScaling recipe computes a single per-tensor outer amax (s_global) per tensor. The per-token variant instead computes a per-row outer amax (length M) for rowwise data and a per-col outer amax (length K) for columnwise data, giving each token/row its own global scale.
Changes
Ongoing work
The per-token recipe currently targets accuracy evaluation, not optimized production deployment:
Type of change
Checklist: